package table

import (
	"bytes"
	"fmt"
	"os"
	"path/filepath"
	"slices"
	"sort"
	"strings"
	"sync"

	"github.com/fatih/color"
	"github.com/samber/lo"
	"github.com/xlab/treeprint"

	"github.com/aquasecurity/table"
	"github.com/aquasecurity/tml"
	dbTypes "github.com/aquasecurity/trivy-db/pkg/types"
	ftypes "github.com/aquasecurity/trivy/pkg/fanal/types"
	"github.com/aquasecurity/trivy/pkg/log"
	"github.com/aquasecurity/trivy/pkg/set"
	"github.com/aquasecurity/trivy/pkg/types"
	"github.com/aquasecurity/trivy/pkg/version/doc"
)

// envDisableNotice names the environment variable that turns the VEX notice off:
// any non-empty value suppresses it.
const envDisableNotice = "TRIVY_DISABLE_VEX_NOTICE"

// vexNotice is the message shown to OSS maintainers about publishing VEX statements.
// It is used as a format string whose single %s verb receives the documentation URL.
const vexNotice = `
For OSS Maintainers: VEX Notice
--------------------------------
If you're an OSS maintainer and Trivy has detected vulnerabilities in your project that you believe are not actually exploitable, consider issuing a VEX (Vulnerability Exploitability eXchange) statement.
VEX allows you to communicate the actual status of vulnerabilities in your project, improving security transparency and reducing false positives for your users.
Learn more and start using VEX: %s

To disable this notice, set the TRIVY_DISABLE_VEX_NOTICE environment variable.

`

var (
	// showVEXNoticeOnce makes the VEX notice print at most once per process.
	showVEXNoticeOnce = new(sync.Once)

	// showSuppressedOnce logs, at most once, a hint that some findings were
	// ignored/suppressed and how to display them.
	showSuppressedOnce = sync.OnceFunc(func() {
		log.Info(`Some vulnerabilities have been ignored/suppressed. Use the "--show-suppressed" flag to display them.`)
	})
)

// vulnerabilityRenderer renders the vulnerability section of a table report.
type vulnerabilityRenderer struct {
	w              *bytes.Buffer      // destination buffer for all rendered output
	isTerminal     bool               // output goes to a terminal: colorize and style with tml
	tree           bool               // show the reversed dependency tree
	showSuppressed bool               // show the table of suppressed vulnerabilities
	severities     []dbTypes.Severity // passed to summarize when computing "Total" lines
	once           *sync.Once         // guards the one-time "package filenames only" log message
}

// NewVulnerabilityRenderer creates a renderer that writes vulnerability tables into buf.
//
// tree enables the reversed dependency tree, suppressed enables the table of
// suppressed vulnerabilities, and severities is passed to summarize when building
// the "Total" lines. When isTerminal is false, tml.DisableFormatting is called. That is
// a package-level tml switch, and it is never re-enabled here.
func NewVulnerabilityRenderer(buf *bytes.Buffer, isTerminal, tree, suppressed bool, severities []dbTypes.Severity) *vulnerabilityRenderer {
	r := &vulnerabilityRenderer{
		w:              buf,
		isTerminal:     isTerminal,
		tree:           tree,
		showSuppressed: suppressed,
		severities:     severities,
		once:           &sync.Once{},
	}
	if !isTerminal {
		// Presumably keeps tml styling escape codes out of non-terminal output.
		tml.DisableFormatting()
	}
	return r
}

// Render writes the vulnerability section for a single result.
//
// The detected-vulnerability table is rendered when needed (see the cases below),
// followed by the dependency tree when it is enabled. Modified (ignored/suppressed)
// findings are then handled in one of two ways. With --show-suppressed they are
// rendered as their own table. Otherwise, if there are any, a hint is logged once
// per process.
func (r *vulnerabilityRenderer) Render(result types.Result) {
	// There are 2 cases when we show the vulnerability table (or only target and `Total: 0...`):
	// 1. The result contains vulnerabilities;
	// 2. Suppressed findings are being shown and the result has modified findings.
	if len(result.Vulnerabilities) > 0 || (r.showSuppressed && len(result.ModifiedFindings) > 0) {
		r.renderDetectedVulnerabilities(result)

		if r.tree {
			r.renderDependencyTree(result)
		}
	}

	if r.showSuppressed {
		r.renderModifiedVulnerabilities(result.ModifiedFindings)
	} else if len(result.ModifiedFindings) > 0 {
		showSuppressedOnce()
	}
}

// renderDetectedVulnerabilities writes the target header, the "Total: ..." severity
// summary and the table of detected vulnerabilities for result.
//
// The first call in the process may also print the VEX notice. This happens only when
// the CI environment variable is set and the opt-out variable (envDisableNotice) is not.
func (r *vulnerabilityRenderer) renderDetectedVulnerabilities(result types.Result) {
	showVEXNoticeOnce.Do(func() {
		optedOut := os.Getenv(envDisableNotice) != ""
		onCI := os.Getenv("CI") != ""
		if onCI && !optedOut {
			_, _ = color.New(color.FgCyan).Fprintf(r.w, vexNotice, doc.URL("guide/supply-chain/vex/repo", "publishing-vex-documents"))
		}
	})

	vulns := result.Vulnerabilities

	// Populate the table up front; it is rendered last, after the target and totals.
	writer := newTableWriter(r.w, r.isTerminal)
	r.setHeaders(writer, vulns)
	r.setVulnerabilityRows(writer, vulns)

	total, summaries := summarize(r.severities, r.countSeverities(vulns))

	target := result.Target
	if result.Class == types.ClassLangPkg {
		// Language-package results also show their package type next to the target.
		target = fmt.Sprintf("%s (%s)", target, result.Type)
	}
	RenderTarget(r.w, target, r.isTerminal)
	r.printf("Total: %d (%s)\n\n", total, strings.Join(summaries, ", "))

	writer.Render()
}

// setHeaders sets the column headers of the vulnerability table.
// No header is set when there are no vulnerabilities to list.
func (r *vulnerabilityRenderer) setHeaders(tw *table.Table, vulns []types.DetectedVulnerability) {
	if len(vulns) > 0 {
		tw.SetHeaders(
			"Library",
			"Vulnerability",
			"Severity",
			"Status",
			"Installed Version",
			"Fixed Version",
			"Title",
		)
	}
}

// setVulnerabilityRows adds one table row per vulnerability in vulns.
//
// The columns are:
//   - library, with the file name of the root archive appended when PkgPath is set;
//   - vulnerability ID;
//   - severity, colorized on a terminal;
//   - status;
//   - installed version;
//   - fixed version;
//   - title.
//
// The title falls back to the description when empty. It is cut to its first
// maxTitleWords space-separated words, with a trailing "..." only when words were
// actually dropped. The primary URL follows on its own line when present.
func (r *vulnerabilityRenderer) setVulnerabilityRows(tw *table.Table, vulns []types.DetectedVulnerability) {
	// Number of space-separated words kept from a title before it is truncated.
	const maxTitleWords = 12

	for _, v := range vulns {
		lib := v.PkgName
		if v.PkgPath != "" {
			// Only the base name of the package's (root archive) file is shown.
			fileName := filepath.Base(rootJarFromPath(v.PkgPath))
			lib = fmt.Sprintf("%s (%s)", v.PkgName, fileName)
			r.once.Do(func() {
				log.Info("Table result includes only package filenames. Use '--format json' option to get the full path to the package file.")
			})
		}

		title := v.Title
		if title == "" {
			title = v.Description
		}
		// Append "..." only when words were actually dropped; a title of exactly
		// maxTitleWords words is shown unchanged.
		if words := strings.Split(title, " "); len(words) > maxTitleWords {
			title = strings.Join(words[:maxTitleWords], " ") + "..."
		}

		if v.PrimaryURL != "" {
			if r.isTerminal {
				title = tml.Sprintf("%s\n<blue>%s</blue>", title, v.PrimaryURL)
			} else {
				title = fmt.Sprintf("%s\n%s", title, v.PrimaryURL)
			}
		}

		// The severity cell is the only one that differs between terminal and plain output.
		severity := v.Severity
		if r.isTerminal {
			severity = ColorizeSeverity(v.Severity, v.Severity)
		}

		tw.AddRow(
			lib,
			v.VulnerabilityID,
			severity,
			v.Status.String(),
			v.InstalledVersion,
			v.FixedVersion,
			strings.TrimSpace(title),
		)
	}
}

// countSeverities returns how many of vulns have each severity, keyed by the
// Severity string. The map is empty (non-nil) when vulns is empty.
func (r *vulnerabilityRenderer) countSeverities(vulns []types.DetectedVulnerability) map[string]int {
	counts := make(map[string]int, len(vulns))
	for i := range vulns {
		counts[vulns[i].Severity]++
	}
	return counts
}

// renderModifiedVulnerabilities writes the "Suppressed Vulnerabilities" table, which
// lists the vulnerability findings among modifiedFindings together with their status,
// statement ("N/A" when empty) and source.
//
// Findings of other types are skipped. Nothing is written when no vulnerability
// finding remains.
func (r *vulnerabilityRenderer) renderModifiedVulnerabilities(modifiedFindings []types.ModifiedFinding) {
	tw := newTableWriter(r.w, r.isTerminal)
	header := []string{
		"Library",
		"Vulnerability",
		"Severity",
		"Status",
		"Statement",
		"Source",
	}
	tw.SetHeaders(header...)

	var total int
	for _, m := range modifiedFindings {
		if m.Type != types.FindingTypeVulnerability {
			continue
		}
		// Checked assertion: a finding whose payload is not a DetectedVulnerability is
		// skipped rather than panicking in the middle of the report.
		vuln, ok := m.Finding.(types.DetectedVulnerability)
		if !ok {
			continue
		}
		total++

		stmt := lo.Ternary(m.Statement != "", m.Statement, "N/A")
		tw.AddRow(vuln.PkgName, vuln.VulnerabilityID, vuln.Severity, string(m.Status), stmt, m.Source)
	}

	if total == 0 {
		return
	}

	title := fmt.Sprintf("Suppressed Vulnerabilities (Total: %d)", total)
	if r.isTerminal {
		// nolint
		_ = tml.Fprintf(r.w, "\n<underline>%s</underline>\n\n", title)
	} else {
		_, _ = fmt.Fprintf(r.w, "\n%s\n", title)
		_, _ = fmt.Fprintf(r.w, "%s\n", strings.Repeat("=", len(title)))
	}

	tw.Render()
}

// renderDependencyTree writes the "Dependency Origin Tree (Reversed)" for result.
//
// Each vulnerable package becomes a top-level branch, labeled with its per-severity
// counts. Under it, addParents lists the direct dependencies the package originates
// from. Nothing is written when no package in result has a parent.
func (r *vulnerabilityRenderer) renderDependencyTree(result types.Result) {
	// Get parents of each dependency
	parents := ftypes.Packages(result.Packages).ParentDeps()
	if len(parents) == 0 {
		return
	}
	ancestors := traverseAncestors(result.Packages, parents)

	root := treeprint.NewWithRoot(fmt.Sprintf(`
Dependency Origin Tree (Reversed)
=================================
%s`, result.Target))

	// Per-package severity counts, shown next to the package ID,
	// e.g. node-fetch@1.7.3 (MEDIUM: 2, HIGH: 1, CRITICAL: 3).
	// A package is vulnerable exactly when it has an entry in this map.
	pkgSeverityCount := make(map[string]map[string]int)
	for _, vuln := range result.Vulnerabilities {
		if _, ok := pkgSeverityCount[vuln.PkgID]; !ok {
			pkgSeverityCount[vuln.PkgID] = make(map[string]int)
		}
		pkgSeverityCount[vuln.PkgID][vuln.Severity]++
	}

	// Render one branch per vulnerable package, in the order of result.Packages.
	// A map lookup replaces a scan of all vulnerabilities for every package.
	for _, pkg := range result.Packages {
		counts, vulnerable := pkgSeverityCount[pkg.ID]
		if !vulnerable {
			continue
		}
		_, summaries := summarize(r.severities, counts)
		topLvlID := tml.Sprintf("<red>%s, (%s)</red>", pkg.ID, strings.Join(summaries, ", "))

		branch := root.AddBranch(topLvlID)
		addParents(branch, pkg, parents, ancestors, set.New(pkg.ID), 1)
	}

	// The rendered tree is data, not a format string: a '%' in the target path or a
	// package ID must not be interpreted as a formatting verb.
	r.printf("%s", root.String())
}

// printf formats args according to format and writes the result to the renderer's
// buffer via tml.Fprintf, so tml markup in format (such as <red>...</red>) is processed.
// Errors from tml.Fprintf are deliberately ignored.
func (r *vulnerabilityRenderer) printf(format string, args ...any) {
	if err := tml.Fprintf(r.w, format, args...); err != nil {
		return // best-effort output: a failed write only loses this fragment
	}
}

// addParents adds to topItem the origins of the package pkg.
//
// Nothing is added when pkg is itself a direct dependency. Otherwise each parent of
// pkg that is not yet in seen is handled as follows:
//   - at depth 1, a parent marked as a direct dependency becomes its own branch;
//   - any other parent is not shown, and its precomputed ancestors are collected instead.
//
// Collected ancestors that are not in seen are then listed, sorted, under a single
// "...(omitted)..." branch. seen is mutated: every handled parent ID is appended to it.
func addParents(topItem treeprint.Tree, pkg ftypes.Package, parentMap map[string]ftypes.Packages, ancestors map[string][]string,
	seen set.Set[string], depth int) {
	if pkg.Relationship == ftypes.RelationshipDirect {
		return
	}

	omitted := set.New[string]()
	for _, p := range parentMap[pkg.ID] {
		if seen.Contains(p.ID) {
			continue
		}
		// Recording the parent skips it if it shows up again (avoiding loops) and
		// keeps it out of the omitted list computed below.
		seen.Append(p.ID)

		if depth == 1 && p.Relationship == ftypes.RelationshipDirect {
			topItem.AddBranch(p.ID)
			continue
		}

		// Intermediate dependencies are not shown individually, since they could
		// make the tree huge; only their direct-dependency ancestors are kept.
		for _, a := range ancestors[p.ID] {
			omitted.Append(a)
		}
	}

	ids := omitted.Difference(seen).Items()
	if len(ids) == 0 {
		return
	}
	sort.Strings(ids)
	group := topItem.AddBranch("...(omitted)...")
	for _, id := range ids {
		group.AddBranch(id)
	}
}

// traverseAncestors precomputes, for every package in pkgs, the IDs returned by
// findAncestor, i.e. the direct dependencies the package is traced back to through
// parentMap. Each lookup starts with a fresh seen set, so packages are resolved
// independently of each other.
func traverseAncestors(pkgs []ftypes.Package, parentMap map[string]ftypes.Packages) map[string][]string {
	result := make(map[string][]string, len(pkgs))
	for i := range pkgs {
		id := pkgs[i].ID
		result[id] = findAncestor(id, parentMap, set.New[string]())
	}
	return result
}

// findAncestor walks parentMap upwards from pkgID and returns the IDs of the parents
// it treats as direct dependencies ("roots").
//
// A parent counts as a root when it is marked as a direct dependency, or when it has
// no parents of its own (see the comment in the switch below). Any other parent is
// explored recursively.
//
// seen is shared across the whole walk and mutated in place: each call adds its own
// pkgID, and parents already in seen are skipped. As a result a package reachable
// through several paths is expanded only once, and cycles cannot recurse forever. The
// result is therefore sensitive to the iteration order of parentMap entries. The
// returned IDs are deduplicated; they come from set.Items, so presumably in no
// particular order (callers that need determinism should sort).
func findAncestor(pkgID string, parentMap map[string]ftypes.Packages, seen set.Set[string]) []string {
	ancestors := set.New[string]()
	// Mark this package before looking at its parents so cycles terminate.
	seen.Append(pkgID)
	for _, parent := range parentMap[pkgID] {
		if seen.Contains(parent.ID) {
			continue
		}
		switch {
		case parent.Relationship == ftypes.RelationshipDirect:
			ancestors.Append(parent.ID)
		case len(parentMap[parent.ID]) == 0:
			// Some package managers, such as "package-lock.json" v1, can retrieve package dependencies but not relationships.
			// We try to guess direct dependencies in this case. A dependency with no parents must be a direct dependency.
			//
			// e.g.
			//   -> styled-components
			//     -> fbjs
			//       -> isomorphic-fetch
			//         -> node-fetch
			//
			// Even if `styled-components` is not marked as a direct dependency, it must be a direct dependency
			// as it has no parents. Note that it doesn't mean `fbjs` is an indirect dependency.
			ancestors.Append(parent.ID)
		default:
			// Intermediate dependency: keep climbing, sharing the same seen set.
			for _, a := range findAncestor(parent.ID, parentMap, seen) {
				ancestors.Append(a)
			}
		}
	}
	return ancestors.Items()
}

// jarExtensions lists the file extensions (matched case-sensitively) that
// rootJarFromPath treats as Java archives.
var jarExtensions = []string{".jar", ".war", ".par", ".ear"}

// rootJarFromPath returns the prefix of path up to and including the first
// slash-separated segment whose extension is one of jarExtensions, i.e. the
// outermost archive in the path. Paths without such a segment (for example,
// packages of other languages) are returned unchanged.
//
// File paths are always forward-slashed in Trivy, so "/" is the only separator considered.
func rootJarFromPath(path string) string {
	rest, end := path, 0
	for {
		segment, tail, more := strings.Cut(rest, "/")
		end += len(segment)
		if slices.Contains(jarExtensions, filepath.Ext(segment)) {
			return path[:end]
		}
		if !more {
			return path
		}
		end++ // step over the "/" separator
		rest = tail
	}
}
